package vfile.translate;

import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.util.NDImageUtils;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.index.NDIndex;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.translate.Batchifier;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;

import java.util.Map;

/**
 * DJL {@link Translator} that converts an {@link Image} into a normalized, batched tensor and
 * hands the raw model output back as an {@link NDList}.
 *
 * <p>Pre-processing resizes the image to a fixed 256x256 and converts it to a float tensor with
 * {@link NDImageUtils#toTensor(NDArray)}. It then normalizes each channel with fixed mean/std
 * values and adds a leading batch axis. Because the batch axis is added here,
 * {@link #getBatchifier()} returns {@code null} so DJL does not batch inputs itself.
 */
public class PtKTranslator implements Translator<Image, NDList> {

    /** Side length, in pixels, of the square image fed to the model. */
    private static final int INPUT_SIZE = 256;

    /** Per-channel normalization mean (the commonly used ImageNet RGB statistics). */
    private static final float[] MEAN = {0.485f, 0.456f, 0.406f};

    /** Per-channel normalization standard deviation (the commonly used ImageNet RGB statistics). */
    private static final float[] STD = {0.229f, 0.224f, 0.225f};

    /** Value used for {@link #maxLength} when no {@code "maxLength"} argument is supplied. */
    private static final int DEFAULT_MAX_LENGTH = 960;

    /**
     * Upper bound on the longer image side, consumed by {@link #scale(int, int, int)}.
     * NOTE(review): not currently applied by {@link #processInput}, which always resizes to a
     * fixed {@code INPUT_SIZE x INPUT_SIZE}.
     */
    private final int maxLength;

    /**
     * Creates the {@link PtKTranslator} translator.
     *
     * @param arguments the arguments for the translator; an optional {@code "maxLength"} entry
     *     whose {@code toString()} is parsed as an integer (defaults to 960 when absent or null)
     * @throws NumberFormatException if {@code "maxLength"} is present but not an integer
     */
    public PtKTranslator(Map<String, ?> arguments) {
        Object configured = arguments.get("maxLength");
        maxLength =
                configured != null
                        ? Integer.parseInt(configured.toString())
                        : DEFAULT_MAX_LENGTH;
    }

    /**
     * {@inheritDoc}
     *
     * <p>Resizes to {@code INPUT_SIZE x INPUT_SIZE}, converts with
     * {@link NDImageUtils#toTensor(NDArray)}, normalizes with {@code MEAN}/{@code STD}, and
     * prepends a batch axis of size 1.
     */
    @Override
    public NDList processInput(TranslatorContext ctx, Image input) {
        NDArray img = input.toNDArray(ctx.getNDManager());
        img = NDImageUtils.resize(img, INPUT_SIZE, INPUT_SIZE);
        img = NDImageUtils.toTensor(img);
        img = NDImageUtils.normalize(img, MEAN, STD);
        // The batch axis is added by hand because getBatchifier() disables DJL's batchifier.
        return new NDList(img.expandDims(0));
    }

    /**
     * {@inheritDoc}
     *
     * <p>Returns the model outputs unchanged, wrapped in a new {@link NDList}.
     */
    @Override
    public NDList processOutput(TranslatorContext ctx, NDList list) {
        // NOTE(review): these arrays come straight from the model; confirm they remain valid
        // after the translator context is closed in the DJL version in use (detach if not).
        return new NDList(list);
    }

    /**
     * {@inheritDoc}
     *
     * <p>Returns {@code null} so that DJL does not add or remove a batch dimension;
     * {@link #processInput} already adds one explicitly.
     */
    @Override
    public Batchifier getBatchifier() {
        return null;
    }

    /**
     * Scales {@code (h, w)} down so the longer side does not exceed {@code max}, then rounds both
     * sides down to multiples of 32 via {@link #resize32(double, double)}.
     *
     * @param h original height
     * @param w original width
     * @param max maximum allowed length of the longer side
     * @return {@code {height, width}}, each a multiple of 32
     */
    private int[] scale(int h, int w, int max) {
        int localMax = Math.max(h, w);
        float scale = 1.0f;
        if (max < localMax) {
            scale = max * 1.0f / localMax;
        }
        // The target model only accepts input sizes that are multiples of 32.
        return resize32(h * scale, w * scale);
    }

    /**
     * Rounds {@code (h, w)} down to multiples of 32. If the shorter side is below 32, both sides
     * are first scaled up proportionally so the shorter side becomes 32.
     *
     * @param h height (must be positive)
     * @param w width (must be positive)
     * @return {@code {height, width}}, each a multiple of 32
     * @throws IllegalArgumentException if either side is not positive
     */
    private int[] resize32(double h, double w) {
        double min = Math.min(h, w);
        if (!(min > 0)) {
            // Guards the division below, which would otherwise yield NaN/Infinity sizes.
            throw new IllegalArgumentException("Image sides must be positive: h=" + h + ", w=" + w);
        }
        if (min < 32) {
            h = 32.0 / min * h;
            w = 32.0 / min * w;
        }
        // Truncate to int first, then integer-divide: floors to the nearest multiple of 32.
        int h32 = ((int) h) / 32;
        int w32 = ((int) w) / 32;
        return new int[]{h32 * 32, w32 * 32};
    }
}
